{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a6e9ee6f",
   "metadata": {},
   "source": [
    "# Example: Analyze the Performance of Surrogates\n",
    "\n",
    "In this notebook, we want to analyze how different surrogates perform.\n",
    "\n",
    "For this, we first need to evaluate the accuracy of the surrogates' predictions. To do this, we can use the `tsbench` CLI:\n",
    "\n",
    "```bash\n",
    "tsbench analysis surrogate \\\n",
    "    --experiment surrogate-analysis \\\n",
    "    --config_path configs/analysis/surrogates.yaml\n",
    "```\n",
    "\n",
    "This command runs a hyperparameter search over the surrogate configurations stored in `configs/analysis/surrogates.yaml` and stores the resulting metrics in MongoDB using Sacred. We can later retrieve these metrics by querying the `surrogate-analysis` experiment in MongoDB.\n",
    "\n",
    "_**Note 1:** The above comment takes quite some time to run since it only terminates once all configurations have been evaluated. Consider running it in a `tmux` session._\n",
    "\n",
    "_**Note 2:** If you want to evaluate surrogate model performance when dataset meta features are used (default for `configs/analysis/surrogates.yaml`), they need to be precomputed. For this, run `tsbench datasets compute-stats` and `tsbench datasets compute-catch22`. The latter command takes some time (up to 2 hours with 48 CPUs) and requires plenty of memory (up to 128 GiB) due to a memory leak in the catch22 library. Consider passing the `--dataset` option to the command and run it for individual datasets if you do not have that much memory available._"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33774ffe",
   "metadata": {},
   "source": [
    "## Load the Metrics from MongoDB\n",
    "\n",
    "First, we want to load the performance of the individual surrogates from MongoDB."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "615c6dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsbench.analysis.tracking import SacredMongoClient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ddd72bfb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "client = SacredMongoClient(\"surrogate-analysis\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1ee341a",
   "metadata": {},
   "outputs": [],
   "source": [
    "configs = []\n",
    "metrics = []\n",
    "for experiment in client:\n",
    "    # Load the metrics file which contains the performances for each left-out dataset\n",
    "    df = experiment.read_parquet(\"results.parquet\")\n",
    "    configs.append(experiment.config)\n",
    "    # Average the performance across all test datasets\n",
    "    metrics.append(df.mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4faaf678",
   "metadata": {},
   "source": [
    "## Build a Data Frame\n",
    "\n",
    "After we have loaded the metrics, we can aggregate them into a dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e9f25bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def flatten(data):\n",
    "    result = {}\n",
    "    for k, v in data.items():\n",
    "        if isinstance(v, dict):\n",
    "            result.update({f\"{k}_{kk}\": vv for kk, vv in flatten(v).items()})\n",
    "        else:\n",
    "            result[k] = str(v)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "92105dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "index_df = pd.DataFrame([flatten(c) for c in configs])\n",
    "df = pd.DataFrame(\n",
    "    metrics, index=pd.MultiIndex.from_frame(index_df)\n",
    ").sort_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9e158ab",
   "metadata": {},
   "source": [
    "## Report the Performance\n",
    "\n",
    "Eventually, we can group them by the type of surrogate (since multiple random evaluations are possible run per\n",
    "surrogate) and compute their average performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fcec02ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "latency_mean  mrr            NaN\n",
       "              ndcg           NaN\n",
       "              nrmse          NaN\n",
       "              precision_10   NaN\n",
       "              precision_20   NaN\n",
       "              precision_5    NaN\n",
       "              smape          NaN\n",
       "ncrps_mean    mrr            NaN\n",
       "              ndcg           NaN\n",
       "              nrmse          NaN\n",
       "              precision_10   NaN\n",
       "              precision_20   NaN\n",
       "              precision_5    NaN\n",
       "              smape          NaN\n",
       "dtype: float64"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.query(\n",
    "    \"surrogate == 'mlp' and \"\n",
    "    \"mlp_objective == 'ranking' and \"\n",
    "    \"mlp_discount == 'linear' and \"\n",
    "    \"inputs_use_simple_dataset_features == 'False'\"\n",
    ").mean().sort_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb141b96",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
